FedProto: Federated Prototype Learning across Heterogeneous Clients
Abstract
Heterogeneity across clients in federated learning (FL) usually hinders the optimization convergence and generalization performance when the aggregation of clients’ knowledge occurs in the gradient space. For example, clients may differ in terms of data distribution, network latency, input/output space, and/or model architecture, which can easily lead to the misalignment of their local gradients. To improve the tolerance to heterogeneity, we propose a novel federated prototype learning (FedProto) framework in which the clients and server communicate the abstract class prototypes instead of the gradients. FedProto aggregates the local prototypes collected from different clients, and then sends the global prototypes back to all clients to regularize the training of local models. The training on each client aims to minimize the classification error on the local data while keeping the resulting local prototypes sufficiently close to the corresponding global ones. Moreover, we provide a theoretical analysis of the convergence rate of FedProto under non-convex objectives. In experiments, we propose a benchmark setting tailored for heterogeneous FL, with FedProto outperforming several recent FL approaches on multiple datasets.
Introduction
Federated learning (FL) is widely used in multiple applications to enable collaborative learning across a variety of clients without sharing private data. It aims at training a global model on a centralized server while all data are distributed over many local clients and cannot be freely transmitted for privacy or communication concerns (McMahan et al. 2017). The iterative process of FL has two steps: (1) each local client is synchronized by the global model and then trained using its local data; and (2) the server updates the global model by aggregating all the local models. Considering that the model aggregation occurs in the gradient space, traditional FL still faces practical challenges caused by the heterogeneity of data and models (Kairouz et al. 2019). Efficient algorithms that overcome both of these challenges have not yet been fully developed or systematically examined.
To tackle the statistical heterogeneity of data distributions, one straightforward solution is to maintain multiple global models for different local distributions, e.g., clustered FL (Sattler, Müller, and Samek 2020). Another widely studied strategy is personalized FL (Tan et al. 2021), where a personalized model is generated for each client by leveraging both global and local information. Nevertheless, most of these methods depend on gradient-based aggregation, resulting in high communication costs and a heavy reliance on homogeneous local models.
However, in real-world applications, model heterogeneity is common because of varying hardware and computation capabilities across clients (Long et al. 2020). Knowledge Distillation (KD)-based FL (Lin et al. 2020) addresses this challenge by transferring the teacher model’s knowledge to student models with different model architectures. However, these methods require an extra public dataset to align the student and teacher models’ outputs, increasing the computation costs. Moreover, the performance of KD-based FL can significantly degrade with the increase in the distribution divergence between the public dataset and on-client datasets that are usually non-IID.
Prototype learning suggests that merging prototypes over heterogeneous datasets can effectively integrate the feature representations from diverse data distributions (Snell, Swersky, and Zemel 2017; Liu et al. 2020; Dvornik, Schmid, and Mairal 2020). On-client intelligent agents in the FL system can share knowledge by exchanging information in terms of representations, despite statistical and model heterogeneity (Cai et al. 2020; Li et al. 2021). For example, when we talk about “dog”, different people will have a unique “imagined picture” or “prototype” to represent the concept “dog”. Their prototypes may be slightly diverse due to different life experiences and visual memories. Exchanging these concept-specific prototypes across people enables them to acquire more knowledge about the concept “dog”. Treating each FL client as a human-like intelligent agent, the core idea of our method is to exchange prototypes rather than share model parameters or raw data, which naturally matches the knowledge acquisition behavior of humans.
In this paper, we propose a novel prototype aggregation-based FL framework where only prototypes are transmitted between the server and clients. The proposed solution does not require model parameters or gradients to be aggregated, so it has a huge potential to be a robust framework for various heterogeneous FL scenarios. Concretely, each client can have different model architectures and input/output spaces, but they can still exchange information by sharing prototypes. Each abstract prototype represents a class by the mean representations transformed from the observed samples belonging to the same class. Aggregating the prototypes allows for efficient communication across heterogeneous clients.
Our main contributions can be summarized as follows:
• We propose a benchmark setting tailored for heterogeneous FL that considers a more general heterogeneous scenario across local clients.
• We present a novel FL method that significantly improves the communication efficiency in the heterogeneous setting. To the best of our knowledge, we are the first to propose prototype aggregation-based FL.
• We theoretically provide a convergence guarantee for our method and carefully derive the convergence rate under non-convex conditions.
• Extensive experiments show the superiority of our proposed method in terms of communication efficiency and test performance on several benchmark datasets.
Related Work
Heterogeneous Federated Learning
Statistical heterogeneity across clients (also known as the non-IID problem) is the most important challenge of FL. FedProx (Li et al. 2020) proposed a local regularization term to optimize each client’s local model. Some recent studies (Arivazhagan et al. 2019; Liang et al. 2020; Deng, Kamani, and Mahdavi 2020) train personalized models to leverage both globally shared information and a client-specific personalized part (Tan et al. 2021; Jiang, Ji, and Long 2020). A third solution is to provide multiple global models by clustering the local models (Mansour et al. 2020; Ghosh et al. 2020; Sattler, Müller, and Samek 2020) into multiple groups. Recently, self-supervised learning strategies have been incorporated into the local training phase to handle the heterogeneity challenges (Li, He, and Song 2021; Liu et al. 2021a; Yang et al. 2021). Fallah, Mokhtari, and Ozdaglar (2020) apply a meta-learning strategy to personalized FL.
Heterogeneous model architecture is another major challenging scenario of FL. The recently proposed KD-based FL (Lin et al. 2020; Jeong et al. 2018; Li and Wang 2020; Long et al. 2021) can serve as an alternative solution to address this challenge. In particular, under the assumption that a shared public dataset is available in the federated setting, these KD-based FL methods can distill knowledge from a teacher model to student models with different model architectures. Some recent studies have also attempted to combine neural architecture search with federated learning (Zhu, Zhang, and Jin 2020; He, Annavaram, and Avestimehr 2020; Singh et al. 2020), which can be applied to discover a customized model architecture for each group of clients with different hardware capabilities and configurations. A collective learning platform is proposed in (Hoang et al. 2019) to handle heterogeneous architectures without access to the local training data and architectures. Moreover, functionality-based neural matching across local models (Wang et al. 2020a) can aggregate neurons with similar functionality regardless of the variance of the model architectures.
However, most of these FL methods focus on only one heterogeneous challenging scenario. All of them use gradient-based aggregation, which raises concerns about communication efficiency and gradient-based attacks (Zhu, Liu, and Han 2019; Chen et al. 2020; Liu et al. 2021b; Zheng et al. 2021).
Prototype Learning
The concept of prototypes (the mean of multiple features) has been explored in a variety of tasks. In image classification, a prototype can be a proxy of a class and is calculated as the mean of the feature vectors within every class (Snell, Swersky, and Zemel 2017). In action recognition, the features of a video at different timestamps can be averaged to serve as the representation of the video (Simonyan and Zisserman 2014; Xue et al. 2021). Aggregated local features can serve as descriptors for image retrieval (Babenko and Lempitsky 2015). Averaging word embeddings as the representation of a sentence can achieve competitive performance on multiple NLP benchmarks (Wieting et al. 2015). Hoang et al. (2020) use prototypes to represent task-agnostic information in distributed machine learning and propose a new fusion paradigm to integrate those prototypes to generate a new model for a new task. In (Michieli and Ozay 2021), prototype margins are used to optimize visual feature representations for FL. In our paper, we borrow the concept of prototypes to represent one class and apply prototype aggregation in the setting of heterogeneous FL.
In general, prototypes are widely used in learning scenarios with a limited number of training samples (Snell, Swersky, and Zemel 2017). This learning scenario is consistent with the latent assumption of cross-client FL: that each client has a limited number of instances to independently train a model with the desired performance. The assumption has been widely supported by the FL-based benchmark datasets (Caldas et al. 2018; He et al. 2020) and in related applications, such as healthcare (Rieke et al. 2020; Xu et al. 2020) and street image object detection (Luo et al. 2019).
Problem Setting
Heterogeneous Federated Learning Setting
In federated learning, each client $i$ owns a local private dataset $D_i$ drawn from distribution $P_i(x, y)$, where $x$ and $y$ denote the input features and corresponding class labels, respectively. Usually, clients share a model with the same architecture and hyperparameters. This model $F(w; x)$ is parameterized by learnable weights $w$ and takes input features $x$. The objective function of FedAvg (McMahan et al. 2017) is:
$$\arg\min_{w} \sum_{i=1}^{m} \frac{|D_i|}{N}\,\mathcal{L}_S\big(F(w; x),\, y\big) \tag{1}$$
where $w$ is the global model’s parameters, $m$ denotes the number of clients, $N = \sum_{i=1}^{m} |D_i|$ is the total number of instances over all clients, $F$ is the shared model, and $\mathcal{L}_S$ is a general definition of any supervised learning loss (e.g., a cross-entropy loss).
However, in a real-world FL environment, each client may represent a mobile phone with a specific user behavior pattern or a sensor deployed in a particular location, leading to a statistically and/or model heterogeneous environment. In the statistical heterogeneity setting, the data distribution $P_i$ varies across clients, indicating heterogeneous input/output spaces for $x$ and $y$. For example, $P_i$ on different clients can be the data distribution over a different subset of classes. In the model heterogeneity setting, the model $F_i$ varies across clients, indicating different model architectures and hyperparameters. For the $i$-th client, the training procedure is to minimize the loss defined below:
$$\arg\min_{w_i} \mathcal{L}_S\big(F_i(w_i; x),\, y\big) \tag{2}$$
Most existing methods cannot handle the heterogeneous settings above well. In particular, when $F_i$ has a different model architecture, its parameters $w_i$ have a different format and size. Thus, the global model’s parameters $w$ cannot be optimized by averaging the $w_i$. To tackle this challenge, we propose to communicate and aggregate prototypes in FL.
Prototype-Based Aggregation Setting
Heterogeneous FL focuses on the robustness to tackle heterogeneous input/output spaces, distributions and model architectures. For example, the datasets $D_i$ and $D_{i'}$ on two clients $i$ and $i'$ may follow different statistical distributions of labels. This is common for a photo classification app installed on mobile clients, where the server needs to recognize many classes in a set $\mathcal{C}$, while each client only needs to recognize a few classes that constitute a subset of $\mathcal{C}$. The class set can vary across clients, though there are overlaps.
In general, deep learning models comprise two parts: (1) representation layers (a.k.a. the embedding function), which transform the input from the original feature space to the embedding space; and (2) decision layers, which make a classification decision for a given learning task.
Representation layers
The embedding function $f_i$ of the $i$-th client is parameterized by $\phi_i$. We denote $h_i = f_i(\phi_i; x)$ as the embedding of $x$.
Decision layers
Given a supervised learning task, a prediction for $x$ can be generated by the function $g_i$ parameterized by $\nu_i$. So, the labelling function can be written as $F_i(\phi_i, \nu_i; x) = g_i(\nu_i; \cdot) \circ f_i(\phi_i; x)$, and we use $F_i(w_i; x)$ with $w_i = (\phi_i, \nu_i)$ for short.
Prototype
We define a prototype $C^{(j)}$ to represent the $j$-th class in $\mathcal{C}$. For the $i$-th client, the prototype $C_i^{(j)}$ is the mean value of the embedding vectors of instances in class $j$,
$$C_i^{(j)} = \frac{1}{|D_{i,j}|} \sum_{(x,y)\in D_{i,j}} f_i(\phi_i; x) \tag{3}$$
where $D_{i,j}$, a subset of the local dataset $D_i$, comprises the training instances belonging to the $j$-th class.
Prototype-based model inference
In the inference stage of the learning task, we can simply predict the label of an instance $x$ by measuring the L2 distance between the instance’s representation $f_i(\phi_i; x)$ and the global prototypes $\overline{C}^{(j)}$ as follows:
$$\hat{y} = \arg\min_{j} \big\| f_i(\phi_i; x) - \overline{C}^{(j)} \big\|_2 \tag{4}$$
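As a concrete illustration, this nearest-prototype inference rule can be sketched in a few lines; `predict` and the dictionary layout of the prototypes are our own illustrative choices, not the paper's code:

```python
import numpy as np

def predict(embedding, global_protos):
    """Assign the label of the global prototype nearest to the
    embedded instance in L2 distance (the rule of Eq. (4))."""
    return min(global_protos,
               key=lambda j: np.linalg.norm(embedding - global_protos[j]))
```

Note that only the embedding function is needed at inference time; the decision layers are bypassed entirely.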
Methodology
We propose a solution for heterogeneous FL that uses prototypes as the key component for exchanging information across the server and the clients.
An overview of the proposed framework is shown in Figure 1. The central server receives local prototype sets from local clients, and then aggregates the prototypes by averaging them. In the heterogeneous FL setting, these prototype sets overlap but are not the same. Taking the MNIST dataset as an example, one client may need to recognize one subset of the handwritten digits while another client recognizes a different subset; the two class sets differ but overlap. The server automatically aggregates prototypes over the overlapping class space across the clients.
Using prototypes in FL, we do not need to exchange gradients or model parameters, which means that the proposed solution can tackle heterogeneous model architectures. Moreover, the prototype-based FL does not require each client to provide the same classes, meaning the heterogeneous class spaces are well supported. Thus, heterogeneity challenges in FL can be addressed.
Optimization Objective
The objective of FedProto is to solve a joint optimization problem on a distributed network. FedProto applies prototype-based communication, which allows a local model to align its prototypes with other local models while minimizing the sum of losses over all clients’ local learning tasks. The objective of federated prototype learning across heterogeneous clients can be formulated as
$$\arg\min_{\{w_i\}_{i=1}^{m}} \sum_{i=1}^{m} \frac{|D_i|}{N}\,\mathcal{L}_S\big(F_i(w_i; x),\, y\big) + \lambda \sum_{i=1}^{m} \sum_{j=1}^{|\mathcal{C}|} \frac{|D_{i,j}|}{N_j}\,\mathcal{L}_R\big(C_i^{(j)},\, \overline{C}^{(j)}\big) \tag{5}$$
where $\mathcal{L}_S$ is the loss of supervised learning (as defined in Eq. (2)) and $\mathcal{L}_R$ is a regularization term that measures the distance (we use L2 distance) between a local prototype $C_i^{(j)}$ and the corresponding global prototype $\overline{C}^{(j)}$. $N$ is the total number of instances over all clients, and $N_j$ is the number of instances belonging to class $j$ over all clients.
The optimization problem can be addressed by alternate minimization that iterates the following two steps: (1) minimization w.r.t. each $w_i$ with the global prototypes $\{\overline{C}^{(j)}\}$ fixed; and (2) minimization w.r.t. $\{\overline{C}^{(j)}\}$ with all $w_i$ fixed. In a distributed setting, step (1) reduces to conventional supervised learning on each client using its local data, while step (2) aggregates local prototypes from local clients on the server end. Further details concerning these two steps can be seen in Algorithm 1.
Algorithm 1: FedProto. Input: local datasets $D_i$ and local models with parameters $w_i$, $i = 1, \ldots, m$. Server executes: in each round, send the global prototypes to the clients, collect the local prototype sets returned by LocalUpdate, and aggregate them by class. LocalUpdate: each client trains its local model on $D_i$ with the prototype-regularized loss and computes its local prototypes.
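A minimal sketch of one communication round follows. To keep it self-contained, each client's local update is reduced to recomputing class-mean prototypes on fixed embeddings, and the weighted aggregation of Eq. (6) is simplified to a plain per-class mean; all function names are illustrative, not from the paper's code:

```python
import numpy as np

def local_prototypes(embeddings, labels):
    # Step (1) stand-in: each client computes class-mean prototypes (Eq. (3))
    return {int(j): embeddings[labels == j].mean(axis=0)
            for j in np.unique(labels)}

def fedproto_round(client_data):
    """One communication round: clients send local prototypes,
    the server averages them per class (step (2), simplified)."""
    local = [local_prototypes(emb, lab) for emb, lab in client_data]
    classes = set().union(*(p.keys() for p in local))
    # A class missing on some clients is averaged only over its owners,
    # which is how overlapping-but-unequal class sets are supported.
    return {j: np.mean([p[j] for p in local if j in p], axis=0)
            for j in classes}
```

In the full method, each client would additionally run gradient steps on the regularized loss of Eq. (7) before recomputing its prototypes.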
Global Prototype Aggregation
Given the data and model heterogeneity of the participating clients, the optimal model parameters for each client are not the same. This means that gradient-based communication cannot sufficiently provide useful information to each client. However, a shared label space allows the participating clients to share a common embedding space, and information can be efficiently exchanged across heterogeneous clients by aggregating prototypes according to the classes they belong to.
Given a class $j$, the server receives prototypes from the set of clients $\mathcal{N}_j$ that own class $j$. A global prototype $\overline{C}^{(j)}$ for class $j$ is generated after the prototype aggregation operation,
$$\overline{C}^{(j)} = \sum_{i \in \mathcal{N}_j} \frac{|D_{i,j}|}{N_j}\, C_i^{(j)} \tag{6}$$
where $C_i^{(j)}$ denotes the prototype of class $j$ from client $i$, and $\mathcal{N}_j$ denotes the set of clients that own class $j$.
Local Model Update
The client needs to update its local model to generate prototypes that are consistent across the clients. To this end, a regularization term is added to the local loss function, pulling the local prototypes toward the global prototypes while minimizing the classification error. In particular, the loss function is defined as follows:
$$\mathcal{L}(D_i, w_i) = \mathcal{L}_S\big(F_i(w_i; x),\, y\big) + \lambda \cdot \mathcal{L}_R \tag{7}$$
where $\lambda$ is an importance weight, and $\mathcal{L}_R$ is the regularization term that can be defined as:
$$\mathcal{L}_R = \sum_{j} d\big(C_i^{(j)},\, \overline{C}^{(j)}\big) \tag{8}$$
where $d(\cdot, \cdot)$ is a distance metric between the locally generated prototypes $C_i^{(j)}$ and the globally aggregated prototypes $\overline{C}^{(j)}$. The distance measurement can take a variety of forms, such as the L1 distance, L2 distance, or earth mover’s distance.
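The regularized local objective of Eqs. (7)–(8) can be sketched as follows, using the L2 distance as the metric $d$; `proto_regularizer` and `local_loss` are illustrative names of our own:

```python
import numpy as np

def proto_regularizer(local_protos, global_protos):
    """Eq. (8) sketch: summed L2 distance between each local prototype
    and its matching global prototype (classes without a global
    counterpart contribute nothing)."""
    return sum(np.linalg.norm(local_protos[j] - global_protos[j])
               for j in local_protos if j in global_protos)

def local_loss(supervised_loss, local_protos, global_protos, lam):
    """Eq. (7) sketch: supervised loss plus the lambda-weighted regularizer."""
    return supervised_loss + lam * proto_regularizer(local_protos, global_protos)
```

In practice the supervised term would come from the client's classifier (e.g., cross-entropy), and the whole sum would be minimized by SGD during LocalUpdate.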
Convergence Analysis
We provide insights into the convergence analysis of FedProto. We denote the local objective function defined in Eq. (7) as $\mathcal{L}$, with a subscript indicating the iteration index, and make the following assumptions, similar to existing general frameworks (Wang et al. 2020b; Li et al. 2020).
Assumption 1.
(Lipschitz Smoothness). Each local objective function is $L_1$-Lipschitz smooth, which means that the gradient of the local objective function is $L_1$-Lipschitz continuous,
$$\big\|\nabla\mathcal{L}_{t_1} - \nabla\mathcal{L}_{t_2}\big\|_2 \le L_1 \big\|w_{i,t_1} - w_{i,t_2}\big\|_2, \quad \forall\, t_1, t_2 > 0,\ i \in \{1, \ldots, m\}. \tag{9}$$
This also implies the following quadratic bound,
$$\mathcal{L}_{t_1} - \mathcal{L}_{t_2} \le \big\langle \nabla\mathcal{L}_{t_2},\, w_{i,t_1} - w_{i,t_2} \big\rangle + \frac{L_1}{2} \big\|w_{i,t_1} - w_{i,t_2}\big\|_2^2. \tag{10}$$
Assumption 2.
(Unbiased Gradient and Bounded Variance). The stochastic gradient $g_{i,t} = \nabla\mathcal{L}(w_{i,t}; \xi_t)$, computed on a mini-batch $\xi_t$ sampled from $D_i$, is an unbiased estimator of the local gradient for each client,
$$\mathbb{E}\big[g_{i,t}\big] = \nabla\mathcal{L}(w_{i,t}), \quad \forall\, i \in \{1, \ldots, m\},\ t > 0, \tag{11}$$
and its variance is bounded by $\sigma^2$:
$$\mathbb{E}\Big[\big\|g_{i,t} - \nabla\mathcal{L}(w_{i,t})\big\|_2^2\Big] \le \sigma^2. \tag{12}$$
Assumption 3.
(Bounded Expectation of Euclidean Norm of Stochastic Gradients). The expectation of the Euclidean norm of the stochastic gradient is bounded by $G$:
$$\mathbb{E}\big[\|g_{i,t}\|_2\big] \le G. \tag{13}$$
Assumption 4.
(Lipschitz Continuity). Each local embedding function is $L_2$-Lipschitz continuous, that is,
$$\big\| f_i(\phi_{i,t_1}) - f_i(\phi_{i,t_2}) \big\|_2 \le L_2 \big\|\phi_{i,t_1} - \phi_{i,t_2}\big\|_2, \quad \forall\, t_1, t_2 > 0,\ i \in \{1, \ldots, m\}. \tag{14}$$
Based on the above assumptions, we present the theoretical results for the non-convex problem. The expected decrease per round is given in Theorem 1. We denote $e \in \{1/2, 1, 2, \ldots, E\}$ as the local iteration, and $t$ as the global communication round. Moreover, $tE$ represents the time step before prototype aggregation, and $tE + 1/2$ represents the time step between prototype aggregation and the first iteration of the current round.
Theorem 1.
(One-round deviation). Let Assumptions 1 to 4 hold. For an arbitrary client, after every communication round, we have
$$\mathbb{E}\big[\mathcal{L}_{(t+1)E+1/2}\big] \le \mathcal{L}_{tE+1/2} - \Big(\eta - \frac{L_1\eta^2}{2}\Big) \sum_{e=1/2}^{E-1} \big\|\nabla\mathcal{L}_{tE+e}\big\|_2^2 + \frac{L_1 E \eta^2}{2}\sigma^2 + \lambda L_2 \eta E G. \tag{15}$$
Theorem 1 indicates the deviation bound of the local objective function for an arbitrary client after each communication round. Convergence can be guaranteed when there is a certain expected one-round decrease, which can be achieved by choosing appropriate $\eta$ and $\lambda$.
Corollary 1.
(Non-convex FedProto convergence). The loss function of an arbitrary client monotonously decreases in every communication round when
$$\eta_{e'} < \frac{2\Big(\sum_{e=1/2}^{e'} \big\|\nabla\mathcal{L}_{tE+e}\big\|_2^2 - \lambda L_2 E G\Big)}{L_1\Big(\sum_{e=1/2}^{e'} \big\|\nabla\mathcal{L}_{tE+e}\big\|_2^2 + E\sigma^2\Big)} \tag{16}$$
where $e' = 1/2, 1, \ldots, E-1$ denotes any local iteration within the round, and
$$\lambda < \frac{\big\|\nabla\mathcal{L}_{tE+1/2}\big\|_2^2}{L_2 E G}. \tag{17}$$
Thus, the loss function converges.
Corollary 1 ensures that the expected one-round deviation of the loss function is negative, so the loss function converges. It can guide the choice of appropriate values for the learning rate $\eta$ and the importance weight $\lambda$ to guarantee convergence.
Theorem 2.
(Non-convex convergence rate of FedProto). Let Assumptions 1 to 4 hold, let $\mathcal{L}_{t_0}$ denote the initial loss, and let $\mathcal{L}^*$ refer to the local optimum. For an arbitrary client, given any $\epsilon > 0$, after
$$T = \frac{2\big(\mathcal{L}_{t_0} - \mathcal{L}^*\big)}{E\eta\big((2 - L_1\eta)\epsilon - L_1\eta\sigma^2 - 2\lambda L_2 G\big)} \tag{18}$$
communication rounds of FedProto, we have
$$\frac{1}{TE} \sum_{t=0}^{T-1} \sum_{e=1/2}^{E-1} \mathbb{E}\Big[\big\|\nabla\mathcal{L}_{tE+e}\big\|_2^2\Big] < \epsilon, \tag{19}$$
if $\epsilon > \dfrac{L_1\eta\sigma^2 + 2\lambda L_2 G}{2 - L_1\eta}$.
Theorem 2 provides the convergence rate, which can confine the expected L2-norm of the gradients to any bound $\epsilon$ after carefully selecting the number of communication rounds $T$ and hyperparameters including $\eta$ and $\lambda$. The smaller $\epsilon$ is, the larger $T$ is; that is, the tighter the bound, the more communication rounds are required. A detailed proof and analysis are given in Appendix B.
Discussion
In this section, we discuss the superiority of FedProto from three perspectives: model inference, communication efficiency, and privacy preserving.
Model Inference
Unlike many FL methods, the global model in FedProto is not a classifier but a set of class prototypes. When a new client joins the network, one can initialize its local model with the representation layers of a pre-trained model, e.g., a ResNet18 pre-trained on ImageNet, and randomly initialized decision layers. The local client then downloads the global prototypes of the classes covered by its local dataset and fine-tunes the local model by minimizing the local objective. This supports new clients with novel model architectures and reduces the time spent fine-tuning the model on heterogeneous datasets.
Communication Efficiency
Our proposed method only transmits prototypes between the server and clients. In general, the size of the prototypes is usually much smaller than the size of the model parameters. Taking MNIST as an example, the prototype size is 50 for each class, while the number of model parameters is 21,500. More details can be found in the experimental section.
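A back-of-envelope check of these numbers: with a prototype size of 50 and 21,500 model parameters per client (both from the text), and assuming roughly 4 classes per client (a hypothetical value consistent with the 4×10³ figure reported for FedProto in Table 1), per-round upload costs compare as follows:

```python
# Numbers from the text; classes_per_client is our own assumption.
proto_dim = 50           # prototype size per class on MNIST
model_params = 21_500    # parameters of the local CNN on MNIST
n_clients = 20
classes_per_client = 4   # hypothetical average

# FedProto uploads one prototype per owned class; FedAvg uploads full models.
fedproto_cost = proto_dim * classes_per_client * n_clients  # values per round
fedavg_cost = model_params * n_clients                      # values per round
```

This yields 4,000 versus 430,000 communicated values per round, i.e., the 4 vs. 430 (×10³) gap shown in Table 1 for MNIST.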
Table 1: Performance comparison. Test accuracy is reported as mean±std (%) over clients for three local-task settings; the number of communicated parameters is per round, in thousands (×10³).

| Dataset | Method | Stdev | Test Avg Acc (setting 1) | Test Avg Acc (setting 2) | Test Avg Acc (setting 3) | # of Comm Rounds | # of Comm Params (×10³) |
|---|---|---|---|---|---|---|---|
| MNIST | Local | 2 | 94.05±2.93 | 93.35±3.26 | 92.92±3.17 | 0 | 0 |
| MNIST | FeSEM (Xie et al. 2020) | 2 | 95.26±3.48 | 97.06±2.72 | 96.31±2.41 | 150 | 430 |
| MNIST | FedProx (Li et al. 2020) | 2 | 96.26±2.89 | 96.40±3.33 | 95.65±3.38 | 110 | 430 |
| MNIST | FedPer (Arivazhagan et al. 2019) | 2 | 95.57±2.96 | 96.44±2.62 | 95.55±3.13 | 100 | 106 |
| MNIST | FedAvg (McMahan et al. 2017) | 2 | 95.04±6.48 | 94.32±4.89 | 93.22±4.39 | 150 | 430 |
| MNIST | FedRep (Collins et al. 2021) | 2 | 94.96±2.78 | 95.18±3.80 | 94.94±2.81 | 100 | 110 |
| MNIST | FedProto | 2 | 97.13±0.30 | 96.80±0.41 | 96.70±0.29 | 100 | 4 |
| MNIST | FedProto-mh | 2 | 97.07±0.50 | 96.65±0.31 | 96.22±0.36 | 100 | 4 |
| FEMNIST | Local | 1 | 92.50±10.42 | 91.16±5.64 | 87.91±8.44 | 0 | 0 |
| FEMNIST | FeSEM (Xie et al. 2020) | 1 | 93.39±6.75 | 91.06±6.43 | 89.61±7.89 | 200 | 16,000 |
| FEMNIST | FedProx (Li et al. 2020) | 1 | 94.53±5.33 | 90.71±6.24 | 91.33±7.32 | 300 | 16,000 |
| FEMNIST | FedPer (Arivazhagan et al. 2019) | 1 | 93.47±5.44 | 90.22±7.63 | 87.73±9.64 | 250 | 102 |
| FEMNIST | FedAvg (McMahan et al. 2017) | 1 | 94.50±5.29 | 91.39±5.23 | 90.95±7.22 | 300 | 16,000 |
| FEMNIST | FedRep (Collins et al. 2021) | 1 | 93.36±5.34 | 91.41±5.89 | 89.98±6.88 | 200 | 102 |
| FEMNIST | FedProto | 1 | 96.82±1.75 | 94.93±1.61 | 93.67±2.23 | 120 | 4 |
| FEMNIST | FedProto-mh | 1 | 97.10±1.63 | 94.83±1.60 | 93.76±2.30 | 120 | 4 |
| CIFAR10 | Local | 1 | 79.72±9.45 | 67.62±7.15 | 58.64±6.57 | 0 | 0 |
| CIFAR10 | FeSEM (Xie et al. 2020) | 1 | 80.19±3.31 | 76.40±3.23 | 74.17±3.51 | 120 | 235,000 |
| CIFAR10 | FedProx (Li et al. 2020) | 1 | 83.25±2.44 | 79.20±1.31 | 76.19±2.23 | 150 | 235,000 |
| CIFAR10 | FedPer (Arivazhagan et al. 2019) | 1 | 84.38±4.58 | 78.73±4.59 | 76.21±4.27 | 130 | 225,000 |
| CIFAR10 | FedAvg (McMahan et al. 2017) | 1 | 81.72±2.77 | 76.77±2.37 | 75.74±2.61 | 150 | 235,000 |
| CIFAR10 | FedRep (Collins et al. 2021) | 1 | 81.44±10.48 | 76.93±7.46 | 73.36±7.04 | 110 | 225,000 |
| CIFAR10 | FedProto | 1 | 84.49±1.97 | 79.12±2.03 | 77.08±1.98 | 110 | 41 |
| CIFAR10 | FedProto-mh | 1 | 83.63±1.60 | 79.49±1.78 | 76.94±1.33 | 110 | 41 |
Privacy Preserving
The proposed FedProto requires the exchange of prototypes rather than model parameters between the server and the clients. This property benefits FL in terms of privacy preservation. First, prototypes naturally protect data privacy, because they are 1D vectors generated by averaging the low-dimensional representations of samples from the same class, which is an irreversible process. Second, attackers cannot reconstruct raw data from prototypes without access to the local models. Moreover, FedProto can be integrated with various privacy-preserving techniques to further enhance the reliability of the system.
Experiments
Training Setups
Datasets and local models
We implement the typical federated setting where each client owns its local data and transmits/receives information to/from the central server. We use three popular benchmark datasets: MNIST (LeCun 1998), FEMNIST (Caldas et al. 2018) and CIFAR10 (Krizhevsky, Hinton et al. 2009). We consider a multi-layer CNN, which consists of two convolutional layers followed by two fully connected layers, for both MNIST and FEMNIST, and ResNet18 (He et al. 2016) for CIFAR10.
Local tasks
Each client learns a supervised learning task. In particular, to illustrate the local task, we borrow the concept of $n$-way $k$-shot from few-shot learning, where $n$ controls the number of classes and $k$ controls the number of training instances per class. To mimic the heterogeneous scenario, we randomly change the values of $n$ and $k$ across clients. We define an average value for $n$ and $k$, and then add random noise to each client’s $n$ and $k$. The variance of $n$ controls the heterogeneity of the class space, while the variance of $k$ controls the imbalance in data size.
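This heterogeneity simulation can be sketched as follows; `sample_client_tasks` and the uniform integer-noise model are our own illustrative assumptions, not the paper's exact sampling code:

```python
import random

def sample_client_tasks(num_clients, classes, n_avg, k_avg, stdev, seed=0):
    """Assign each client an n-way k-shot task, where n and k are
    perturbed per client by integer noise in [-stdev, stdev]."""
    rng = random.Random(seed)
    tasks = []
    for _ in range(num_clients):
        n = max(2, n_avg + rng.randint(-stdev, stdev))  # classes per client
        k = max(1, k_avg + rng.randint(-stdev, stdev))  # shots per class
        tasks.append((sorted(rng.sample(classes, n)), k))
    return tasks
```

Varying `stdev` then reproduces the different heterogeneity levels (Stdev 1 or 2) used in the experiments.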
Baselines of FL
We study the performance of FedProto under both the statistical and model heterogeneous settings (FedProto-mh) and make comparisons with baselines, including Local where an individual model is trained for each client without any communication with others, FedAvg (McMahan et al. 2017), FedProx (Li et al. 2020), FeSEM (Xie et al. 2020), FedPer (Arivazhagan et al. 2019), and FedRep (Collins et al. 2021).
Implementation Details
We implement FedProto and the baseline methods in PyTorch. We use 20 clients for all datasets and all clients are sampled in each communication round. The average size of each class in each client is set to be 100. For MNIST and FEMNIST dataset, our initial set of hyperparameters was taken directly from the default set of hyperparamters in (McMahan et al. 2017). For CIFAR10, ResNet18 pre-trained on ImageNet (Krizhevsky, Sutskever, and Hinton 2017) is used as the initial model. The initial average test accuracy of the pre-trained network on CIFAR10 is 27.55. A detailed setup including the choice of hyperparameters is given in Appendix A.
Performance in Non-IID Federated Setting
We compare FedProto with other baseline methods that are either classical FL methods or FL methods with an emphasis on statistical heterogeneity. All methods are adapted to fit this heterogeneous setting.
Statistical heterogeneity simulations
In our setting, we assume that all clients perform learning tasks with heterogeneous statistical distributions. In order to simulate different levels of heterogeneity, we fix the standard deviation to be 1 or 2, aiming to create heterogeneity in both class spaces and data sizes, which is common in real-world scenarios.
Model heterogeneity simulations
For the model heterogeneous setting, we consider minor differences in model architectures across clients. In MNIST and FEMNIST, the number of output channels in the convolutional layers is set to either 18, 20 or 22, while in CIFAR10, the stride of convolutional layers is set differently across different clients. This kind of model heterogeneity brings about challenges for model parameter averaging because the parameters in different clients are not always the same size.
The average test accuracy over all clients is reported in Table 1. It can be seen that FedProto achieves the highest accuracy and the least variance in most cases, ensuring uniformity among heterogeneous clients.
Communication efficiency
Communication costs have always been a challenge in FL, considering several limitations in existing communication channels. Therefore, we also report the number of communication rounds required for convergence and the number of parameters communicated per round in Table 1. It can be seen that the number of parameters communicated per round in FedProto is much lower than in the case of FedAvg. Furthermore, FedProto requires the fewest communication rounds to converge. This suggests that when the heterogeneity level is high across the clients, sharing more parameters does not always lead to better results. It is more important to identify which part to share so as to best benefit the overall system. More performance results are shown in Appendix A.
Visualization of prototypes achieved by FedProto
We visualize the samples in the MNIST test set by t-SNE (Van der Maaten and Hinton 2008). In Figure 2(a), small points in different colors represent samples in different classes, with large points representing the corresponding global prototypes. In Figures 2(b), 2(c) and 2(d), the points in different colors refer to the representations of samples belonging to different classes. Better generalization means that more samples of the same class cluster in the same area, which can be achieved in a centralized setting, while better personalization means that it is easier to determine to which client the samples belong. It can be seen that samples within the same class but from various clients are close but separable in FedProto. This indicates that FedProto is more successful in achieving the balance between generalization and personalization, while other methods lack either the generalization or the personalization ability.
Scalability of FedProto on varying number of samples
Figure 3 shows that FedProto can scale to scenarios with fewer samples available on clients. The test accuracy consistently decreases when fewer samples are available for training, but FedProto degrades more slowly than FedAvg, reflecting its adaptability and scalability across various data sizes.
FedProto under varying $\lambda$
Figure 4 shows the varying performance under different values of $\lambda$ in Eq. (5). We tried a set of candidate values and report the average test accuracy and prototype distance loss with $n$=3, $k$=100 on the FEMNIST dataset. As $\lambda$ increases, the prototype distance loss (regularization term) decreases, while the average test accuracy rises sharply from $\lambda$=0 to $\lambda$=1 and then drops as $\lambda$ grows further, demonstrating the efficacy of prototype aggregation. The best value of $\lambda$ is 1 in this scenario.
Conclusion
In this paper, we propose a novel prototype aggregation-based FL method to tackle challenging FL scenarios with heterogeneous input/output spaces, data distributions, and model architectures. The proposed method collaboratively trains intelligent models by exchanging prototypes rather than gradients, which offers new insights for designing prototype-based FL. The effectiveness of the proposed method has been comprehensively analyzed from both theoretical and experimental perspectives.
References
- Arivazhagan et al. (2019) Arivazhagan, M. G.; Aggarwal, V.; Singh, A. K.; and Choudhary, S. 2019. Federated learning with personalization layers. arXiv preprint arXiv:1912.00818.
- Babenko and Lempitsky (2015) Babenko, A.; and Lempitsky, V. 2015. Aggregating local deep features for image retrieval. In Proceedings of the IEEE international conference on computer vision, 1269–1277.
- Cai et al. (2020) Cai, T.; Li, J.; Mian, A. S.; Sellis, T.; Yu, J. X.; et al. 2020. Target-aware holistic influence maximization in spatial social networks. IEEE Transactions on Knowledge and Data Engineering.
- Caldas et al. (2018) Caldas, S.; Duddu, S. M. K.; Wu, P.; Li, T.; Konečnỳ, J.; McMahan, H. B.; Smith, V.; and Talwalkar, A. 2018. LEAF: A benchmark for federated settings. arXiv:1812.01097.
- Chen et al. (2020) Chen, C.; Zhang, J.; Tung, A. K.; Kankanhalli, M.; and Chen, G. 2020. Robust federated recommendation system. arXiv preprint arXiv:2006.08259.
- Collins et al. (2021) Collins, L.; Hassani, H.; Mokhtari, A.; and Shakkottai, S. 2021. Exploiting Shared Representations for Personalized Federated Learning. International Conference on Machine Learning.
- Deng, Kamani, and Mahdavi (2020) Deng, Y.; Kamani, M. M.; and Mahdavi, M. 2020. Adaptive Personalized Federated Learning. arXiv:2003.13461.
- Dvornik, Schmid, and Mairal (2020) Dvornik, N.; Schmid, C.; and Mairal, J. 2020. Selecting relevant features from a universal representation for few-shot classification.
- Fallah, Mokhtari, and Ozdaglar (2020) Fallah, A.; Mokhtari, A.; and Ozdaglar, A. 2020. Personalized Federated Learning with Theoretical Guarantees: A Model-Agnostic Meta-Learning Approach. In Advances in Neural Information Processing Systems.
- Ghosh et al. (2020) Ghosh, A.; Chung, J.; Yin, D.; and Ramchandran, K. 2020. An Efficient Framework for Clustered Federated Learning. In Advances in Neural Information Processing Systems.
- He, Annavaram, and Avestimehr (2020) He, C.; Annavaram, M.; and Avestimehr, S. 2020. FedNAS: Federated Deep Learning via Neural Architecture Search. In Proceedings of the IEEE conference on computer vision and pattern recognition.
- He et al. (2020) He, C.; Li, S.; So, J.; Zhang, M.; Wang, H.; Wang, X.; Vepakomma, P.; Singh, A.; Qiu, H.; Shen, L.; et al. 2020. Fedml: A research library and benchmark for federated machine learning. arXiv:2007.13518.
- He et al. (2016) He, K.; Zhang, X.; Ren, S.; and Sun, J. 2016. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, 770–778.
- Hoang et al. (2019) Hoang, M.; Hoang, N.; Low, B. K. H.; and Kingsford, C. 2019. Collective model fusion for multiple black-box experts. In International Conference on Machine Learning, 2742–2750. PMLR.
- Hoang et al. (2020) Hoang, N.; Lam, T.; Low, B. K. H.; and Jaillet, P. 2020. Learning Task-Agnostic Embedding of Multiple Black-Box Experts for Multi-Task Model Fusion. In International Conference on Machine Learning, 4282–4292. PMLR.
- Jeong et al. (2018) Jeong, E.; Oh, S.; Kim, H.; Park, J.; Bennis, M.; and Kim, S.-L. 2018. Communication-efficient on-device machine learning: Federated distillation and augmentation under non-IID private data. In Advances in Neural Information Processing Systems.
- Jiang, Ji, and Long (2020) Jiang, J.; Ji, S.; and Long, G. 2020. Decentralized knowledge acquisition for mobile internet applications. World Wide Web, 1–17.
- Kairouz et al. (2019) Kairouz, P.; McMahan, H. B.; Avent, B.; Bellet, A.; Bennis, M.; Bhagoji, A. N.; et al. 2019. Advances and open problems in federated learning. arXiv:1912.04977.
- Krizhevsky, Hinton et al. (2009) Krizhevsky, A.; Hinton, G.; et al. 2009. Learning multiple layers of features from tiny images.
- Krizhevsky, Sutskever, and Hinton (2017) Krizhevsky, A.; Sutskever, I.; and Hinton, G. E. 2017. Imagenet classification with deep convolutional neural networks. Communications of the ACM, 60(6): 84–90.
- LeCun (1998) LeCun, Y. 1998. The MNIST database of handwritten digits. http://yann.lecun.com/exdb/mnist/.
- Li and Wang (2020) Li, D.; and Wang, J. 2020. Fedmd: Heterogenous federated learning via model distillation. In Advances in Neural Information Processing Systems.
- Li, He, and Song (2021) Li, Q.; He, B.; and Song, D. 2021. Model-Contrastive Federated Learning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 10713–10722.
- Li et al. (2020) Li, T.; Sahu, A. K.; Zaheer, M.; Sanjabi, M.; Talwalkar, A.; and Smith, V. 2020. Federated optimization in heterogeneous networks. MLSys.
- Li et al. (2021) Li, Z.; Wang, X.; Li, J.; and Zhang, Q. 2021. Deep attributed network representation learning of complex coupling and interaction. Knowledge-Based Systems, 212: 106618.
- Liang et al. (2020) Liang, P. P.; Liu, T.; Ziyin, L.; Salakhutdinov, R.; and Morency, L.-P. 2020. Think Locally, Act Globally: Federated Learning with Local and Global Representations. Advances in Neural Information Processing Systems.
- Lin et al. (2020) Lin, T.; Kong, L.; Stich, S. U.; and Jaggi, M. 2020. Ensemble Distillation for Robust Model Fusion in Federated Learning. In Advances in Neural Information Processing Systems.
- Liu et al. (2020) Liu, L.; Hamilton, W. L.; Long, G.; Jiang, J.; and Larochelle, H. 2020. A Universal Representation Transformer Layer for Few-Shot Image Classification. In International Conference on Learning Representations.
- Liu et al. (2021a) Liu, Y.; Pan, S.; Jin, M.; Zhou, C.; Xia, F.; and Yu, P. S. 2021a. Graph self-supervised learning: A survey. arXiv preprint arXiv:2103.00111.
- Liu et al. (2021b) Liu, Y.; Pan, S.; Wang, Y. G.; Xiong, F.; Wang, L.; and Lee, V. 2021b. Anomaly Detection in Dynamic Graphs via Transformer. IEEE Transactions on Knowledge and Data Engineering.
- Long et al. (2021) Long, G.; Shen, T.; Tan, Y.; Gerrard, L.; Clarke, A.; and Jiang, J. 2021. Federated learning for privacy-preserving open innovation future on digital health. In Humanity Driven AI. Springer.
- Long et al. (2020) Long, G.; Tan, Y.; Jiang, J.; and Zhang, C. 2020. Federated Learning for Open Banking. In Federated Learning, 240–254. Springer.
- Luo et al. (2019) Luo, J.; Wu, X.; Luo, Y.; Huang, A.; Huang, Y.; Liu, Y.; and Yang, Q. 2019. Real-world image datasets for federated learning. arXiv:1910.11089.
- Mansour et al. (2020) Mansour, Y.; Mohri, M.; Ro, J.; and Suresh, A. T. 2020. Three approaches for personalization with applications to federated learning. arXiv:2002.10619.
- McMahan et al. (2017) McMahan, H. B.; Moore, E.; Ramage, D.; et al. 2017. Communication-efficient learning of deep networks from decentralized data. AISTATS.
- Michieli and Ozay (2021) Michieli, U.; and Ozay, M. 2021. Prototype Guided Federated Learning of Visual Feature Representations. arXiv preprint arXiv:2105.08982.
- Rieke et al. (2020) Rieke, N.; Hancox, J.; Li, W.; Milletari, F.; Roth, H. R.; Albarqouni, S.; Bakas, S.; Galtier, M. N.; Landman, B. A.; Maier-Hein, K.; et al. 2020. The future of digital health with federated learning. NPJ digital medicine, 3(1): 1–7.
- Sattler, Müller, and Samek (2020) Sattler, F.; Müller, K.-R.; and Samek, W. 2020. Clustered federated learning: Model-agnostic distributed multitask optimization under privacy constraints. IEEE transactions on neural networks and learning systems.
- Simonyan and Zisserman (2014) Simonyan, K.; and Zisserman, A. 2014. Two-Stream Convolutional Networks for Action Recognition in Videos. In Advances in Neural Information Processing Systems, 568–576.
- Singh et al. (2020) Singh, I.; Zhou, H.; Yang, K.; Ding, M.; Lin, B.; and Xie, P. 2020. Differentially-private federated neural architecture search. In FL-International Conference on Machine Learning Workshop.
- Snell, Swersky, and Zemel (2017) Snell, J.; Swersky, K.; and Zemel, R. 2017. Prototypical Networks for Few-shot Learning. Advances in Neural Information Processing Systems, 30: 4077–4087.
- Tan et al. (2021) Tan, A. Z.; Yu, H.; Cui, L.; and Yang, Q. 2021. Towards personalized federated learning. arXiv preprint arXiv:2103.00710.
- Van der Maaten and Hinton (2008) Van der Maaten, L.; and Hinton, G. 2008. Visualizing data using t-SNE. Journal of machine learning research, 9(11).
- Wang et al. (2020a) Wang, H.; Yurochkin, M.; Sun, Y.; Papailiopoulos, D.; and Khazaeni, Y. 2020a. Federated Learning with Matched Averaging. In International Conference on Learning Representations.
- Wang et al. (2020b) Wang, J.; Liu, Q.; Liang, H.; Joshi, G.; and Poor, H. V. 2020b. Tackling the Objective Inconsistency Problem in Heterogeneous Federated Optimization. Advances in neural information processing systems.
- Wieting et al. (2015) Wieting, J.; Bansal, M.; Gimpel, K.; and Livescu, K. 2015. Towards universal paraphrastic sentence embeddings. arXiv:1511.08198.
- Xie et al. (2020) Xie, M.; Long, G.; Shen, T.; Wang, X.; Tianyi, Z.; and Jiang, J. 2020. Multi-center Federated Learning. arXiv:2005.01026.
- Xu et al. (2020) Xu, J.; Glicksberg, B. S.; Su, C.; Walker, P.; Bian, J.; and Wang, F. 2020. Federated learning for healthcare informatics. Journal of Healthcare Informatics Research, 1–19.
- Xue et al. (2021) Xue, G.; Zhong, M.; Li, J.; Chen, J.; Zhai, C.; and Kong, R. 2021. Dynamic network embedding survey. arXiv preprint arXiv:2103.15447.
- Yang et al. (2021) Yang, Y.; Guan, Z.; Li, J.; Zhao, W.; Cui, J.; and Wang, Q. 2021. Interpretable and efficient heterogeneous graph convolutional network. IEEE Transactions on Knowledge and Data Engineering.
- Zheng et al. (2021) Zheng, Y.; Jin, M.; Liu, Y.; Chi, L.; Phan, K. T.; and Chen, Y.-P. P. 2021. Generative and Contrastive Self-Supervised Learning for Graph Anomaly Detection. IEEE Transactions on Knowledge and Data Engineering.
- Zhu, Zhang, and Jin (2020) Zhu, H.; Zhang, H.; and Jin, Y. 2020. From federated learning to federated neural architecture search: a survey. Complex & Intelligent Systems, 1–19.
- Zhu, Liu, and Han (2019) Zhu, L.; Liu, Z.; and Han, S. 2019. Deep Leakage from Gradients. In Advances in Neural Information Processing Systems, 14774–14784.
We present the related supplements in the following sections.
Experimental Details and Extra Results
Experimental Details
Local clients are trained with the SGD optimizer using a fixed learning rate and momentum. The crucial regularization hyperparameter is tuned by grid search over a limited candidate set, with the best value selected separately for MNIST, FEMNIST, and CIFAR10. The number of local epochs and the local batch size are set to 1 and 8, respectively, for all datasets. The heterogeneity level of the clients is controlled by a standard deviation parameter: the higher it is, the more heterogeneous the clients are.
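A tuning loop consistent with this setup might look like the following sketch. The candidate values and the `run_federated_training` callback are placeholders for illustration, not the paper's actual grid or training code:

```python
# Hypothetical grid search for the prototype-regularization weight:
# train with each candidate value and keep the one with the best
# mean local test accuracy (SGD, 1 local epoch, batch size 8 assumed
# to happen inside the callback).
def tune_lambda(candidates, run_federated_training):
    """Return (best_lambda, best_accuracy) over the candidate set.
    `run_federated_training` is an assumed callback: lambda_ -> accuracy."""
    best_lam, best_acc = None, float("-inf")
    for lam in candidates:
        acc = run_federated_training(lam)
        if acc > best_acc:
            best_lam, best_acc = lam, acc
    return best_lam, best_acc

# Usage with a stand-in evaluation (a toy concave function of lambda):
lam, acc = tune_lambda([0.1, 0.5, 1.0, 2.0], lambda l: 1.0 - (l - 1.0) ** 2)
```

In practice the callback would run the full federated training once per candidate, so the candidate set is kept small.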
Extra Results
The complete experimental results show the performance of FedProto and FedProto-mh on three benchmark datasets: MNIST, FEMNIST, and CIFAR10. Compared with existing FL methods, FedProto yields higher test accuracy with lower communication costs under different heterogeneous settings. Additionally, it can be used in model-heterogeneous scenarios, where it achieves performance similar to that in model-homogeneous scenarios.
For MNIST, we evaluate on the local test sets and report the results in Table 2. FedProto achieves strong performance at low communication cost: its local average test accuracy is greater than that of the FeSEM, FedProx, FedPer, and FedAvg algorithms in all settings.
For FEMNIST, the evaluation results are reported in Table 3. We consider standard deviations of 1 and 2. The results show that, for FedProto, the variance of the accuracy across clients is much smaller than for the other FL methods, ensuring better uniformity among heterogeneous clients. FedProto also better utilizes the local FEMNIST data distribution while communicating only a small fraction of the parameters required by the other methods.
For CIFAR10, as can be seen in Table 4, FedProto converges faster in the presence of heterogeneity in most cases. With FedProto and FedProto-mh, the number of parameters communicated per round is much lower than for the baseline methods, greatly reducing communication costs.
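As a rough illustration of why prototype exchange is so much cheaper, the per-round upload of a prototype-based scheme versus a gradient-based one can be compared. All sizes below are made-up placeholders, not the values in the tables:

```python
# Back-of-envelope per-round upload cost (illustrative, assumed sizes):
# gradient-based FL ships the full model, while a prototype-based scheme
# ships one embedding vector per class.
def params_per_round(method, num_classes, proto_dim, model_params):
    if method == "prototype":
        return num_classes * proto_dim   # one prototype vector per class
    return model_params                  # every model parameter

model_params = 430_000                   # hypothetical CNN size
protos = params_per_round("prototype", num_classes=10, proto_dim=64,
                          model_params=model_params)
grads = params_per_round("gradient", num_classes=10, proto_dim=64,
                         model_params=model_params)
ratio = protos / grads                   # fraction of the gradient cost
```

The prototype payload scales with the number of classes and the embedding dimension, not with model size, which is why it stays small even for large local models.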
| Dataset | Method | Stdev | Test Average Acc | | | # of Comm Rounds | # of Comm Params |
|---|---|---|---|---|---|---|---|
| MNIST | Local | 2 | 94.05±2.93 | 93.35±3.26 | 92.92±3.17 | 0 | 0 |
| | | 3 | 93.44±3.57 | 94.24±2.49 | 93.97±2.97 | | |
| | FeSEM | 2 | 95.26±3.48 | 97.06±2.72 | 96.31±2.41 | 150 | 430 |
| | | 3 | 96.40±3.35 | 95.82±3.94 | 95.98±2.46 | | |
| | FedProx | 2 | 96.26±2.89 | 96.40±3.33 | 95.65±3.38 | 110 | 430 |
| | | 3 | 96.65±3.28 | 95.25±3.73 | 95.34±2.85 | | |
| | FedPer | 2 | 95.57±2.96 | 96.44±2.62 | 95.55±3.13 | 100 | 106 |
| | | 3 | 96.57±2.65 | 95.93±2.76 | 96.07±2.80 | | |
| | FedAvg | 2 | 91.40±6.48 | 94.32±4.89 | 93.22±4.39 | 150 | 430 |
| | | 3 | 94.57±4.91 | 91.99±6.89 | 92.19±3.97 | | |
| | FedRep | 2 | 94.96±2.78 | 95.18±3.80 | 94.94±2.81 | 100 | 110 |
| | | 3 | 95.01±3.92 | 95.55±2.79 | 95.38±2.97 | | |
| | FedProto | 2 | 97.13±0.30 | 96.80±0.41 | 96.70±0.29 | 100 | 4 |
| | | 3 | 96.71±0.43 | 96.87±0.28 | 96.47±0.23 | | |
| | FedProto-mh | 2 | 97.07±0.50 | 96.65±0.31 | 96.22±0.36 | 100 | 4 |
| | | 3 | 96.48±0.43 | 96.84±0.33 | 95.56±0.31 | | |
| Dataset | Method | Stdev | Test Average Acc | | | # of Comm Rounds | # of Comm Params |
|---|---|---|---|---|---|---|---|
| FEMNIST | Local | 1 | 92.50±10.42 | 91.16±5.64 | 87.91±8.44 | 0 | 0 |
| | | 2 | 92.11±6.02 | 90.34±6.42 | 89.70±6.33 | | |
| | FeSEM | 1 | 93.39±6.75 | 91.06±6.43 | 89.61±7.89 | 200 | 16,000 |
| | | 2 | 94.19±4.90 | 93.52±4.47 | 90.77±6.70 | | |
| | FedProx | 1 | 94.53±5.33 | 90.71±6.24 | 91.33±7.32 | 300 | 16,000 |
| | | 2 | 93.49±5.30 | 93.74±5.02 | 89.49±6.74 | | |
| | FedPer | 1 | 93.47±5.44 | 90.22±7.63 | 87.73±9.64 | 250 | 102 |
| | | 2 | 92.27±6.16 | 91.99±6.33 | 87.54±8.14 | | |
| | FedAvg | 1 | 94.50±5.29 | 91.39±5.23 | 90.95±7.22 | 300 | 16,000 |
| | | 2 | 94.13±4.92 | 93.02±5.77 | 89.80±6.94 | | |
| | FedRep | 1 | 93.36±5.34 | 91.41±5.89 | 89.98±6.88 | 200 | 102 |
| | | 2 | 92.28±5.40 | 91.56±7.02 | 88.23±6.97 | | |
| | FedProto | 1 | 96.82±1.75 | 94.93±1.61 | 93.67±2.23 | 120 | 4 |
| | | 2 | 94.93±1.29 | 94.69±1.50 | 93.03±2.50 | | |
| | FedProto-mh | 1 | 97.10±1.63 | 94.83±1.60 | 93.76±2.30 | 120 | 4 |
| | | 2 | 95.33±1.30 | 94.98±1.69 | 92.94±2.34 | | |
| Dataset | Method | Stdev | Test Average Acc | | | # of Comm Rounds | # of Comm Params |
|---|---|---|---|---|---|---|---|
| CIFAR10 | Local | 1 | 79.72±9.45 | 67.62±7.15 | 58.64±6.57 | 0 | 0 |
| | | 2 | 68.15±9.88 | 61.03±11.83 | 58.81±12.90 | | |
| | FeSEM | 1 | 80.19±3.31 | 76.40±3.23 | 74.17±3.51 | 120 | 2.35 |
| | | 2 | 76.12±4.15 | 72.11±3.48 | 70.89±3.39 | | |
| | FedProx | 1 | 83.25±2.44 | 79.20±1.31 | 76.19±2.23 | 150 | 2.35 |
| | | 2 | 79.83±2.35 | 72.56±1.90 | 71.39±2.36 | | |
| | FedPer | 1 | 84.38±4.58 | 78.73±4.59 | 76.21±4.27 | 130 | 2.25 |
| | | 2 | 84.51±4.39 | 73.31±4.76 | 72.43±4.55 | | |
| | FedAvg | 1 | 81.72±2.77 | 76.77±2.37 | 75.74±2.61 | 150 | 2.35 |
| | | 2 | 78.99±2.34 | 72.73±2.58 | 70.93±2.82 | | |
| | FedRep | 1 | 81.44±10.48 | 76.93±7.46 | 73.36±7.04 | 110 | 2.25 |
| | | 2 | 76.70±11.79 | 73.54±11.42 | 70.30±8.00 | | |
| | FedProto | 1 | 84.49±1.97 | 79.12±2.03 | 77.08±1.98 | 110 | 4.10 |
| | | 2 | 81.75±1.39 | 74.98±1.61 | 71.17±1.29 | | |
| | FedProto-mh | 1 | 83.63±1.60 | 79.49±1.78 | 76.94±1.33 | 110 | 4.10 |
| | | 2 | 79.90±1.08 | 75.78±1.05 | 72.67±1.09 | | |
Convergence Analysis for FedProto
Additional Notation
Here, additional variables are introduced to better represent the process of local model updates. Let $f_i(\phi_i)$ denote the embedding function of the $i$-th client with parameters $\phi_i$, which may differ across clients. The input dimension $d_x$ and the prototype dimension $d_z$ are the same for all clients. All clients share the form of the decision function $g$, whose output $y$ has dimension $d_y$, so the labelling function can be written as the composition of $g$ and $f_i$, which we sometimes abbreviate as $F_i$. In the theoretical analysis, we omit the class label of the prototype for convenience, which does not affect the proof. For brevity, we also use shorthand symbols for the weight of the prototype term and for the loss function of the $i$-th client.
Therefore, the local loss function of client $i$ can be written as:
| (1) |
in which the global prototype is defined as
| (2) |
with
| (3) |
and
| (4) |
The global prototype is a constant within each communication round but changes across rounds, which makes the convergence analysis more complex.
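The local loss in Eq. 1 and the prototype aggregation in Eqs. 2–4 can be sketched as follows. This is a minimal NumPy sketch under assumptions: the names, shapes, and the squared-L2 prototype distance are illustrative choices, not necessarily the paper's exact formulation:

```python
import numpy as np

def local_prototypes(embeddings, labels):
    """Per-class mean embedding on a client's local data (cf. Eq. 4)."""
    return {c: embeddings[labels == c].mean(axis=0) for c in np.unique(labels)}

def global_prototypes(client_protos, client_counts):
    """Server: per-class weighted average of local prototypes (cf. Eq. 2)."""
    agg = {}
    for protos, counts in zip(client_protos, client_counts):
        for c, p in protos.items():
            v, n = agg.get(c, (0.0, 0))
            agg[c] = (v + counts[c] * p, n + counts[c])
    return {c: v / n for c, (v, n) in agg.items()}

def local_loss(task_loss, protos, g_protos, lam):
    """Eq. 1 (sketched): task loss plus a lambda-weighted prototype distance."""
    reg = sum(np.sum((protos[c] - g_protos[c]) ** 2)
              for c in protos if c in g_protos)
    return task_loss + lam * reg

# Tiny worked example with one client and two classes:
emb = np.array([[1.0, 0.0], [3.0, 0.0], [0.0, 2.0]])
y = np.array([0, 0, 1])
p1 = local_prototypes(emb, y)                   # class means of the embeddings
g = global_prototypes([p1], [{0: 2, 1: 1}])     # single client: same prototypes
loss = local_loss(0.5, p1, g, lam=1.0)          # regularizer is zero here
```

With a single client, the global prototypes coincide with the local ones and the regularizer vanishes; across many clients, the regularizer pulls each client's class means toward the aggregated ones.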
As for the iteration notation, we use $t$ to denote the communication round and $e$ to denote the local iteration. There are $E$ local iterations in total, so $tE + e$ refers to the $e$-th local iteration in communication round $t$. Moreover, $tE$ denotes the time step before prototype aggregation at the server, and $tE + 1/2$ denotes the time step between prototype aggregation at the server and the start of the first iteration on the local model.
Assumptions
Assumption 1.
(Lipschitz Smooth). Each local objective function is $L_1$-Lipschitz smooth, which also means that the gradient of the local objective function is $L_1$-Lipschitz continuous:
$$\left\| \nabla L_{t_1} - \nabla L_{t_2} \right\|_2 \le L_1 \left\| \omega_{t_1} - \omega_{t_2} \right\|_2, \quad \forall t_1, t_2 > 0, \qquad (5)$$
which implies the following quadratic bound:
$$L_{t_1} \le L_{t_2} + \left\langle \nabla L_{t_2},\, \omega_{t_1} - \omega_{t_2} \right\rangle + \frac{L_1}{2} \left\| \omega_{t_1} - \omega_{t_2} \right\|_2^2, \quad \forall t_1, t_2 > 0. \qquad (6)$$
Assumption 2.
(Unbiased Gradient and Bounded Variance). The stochastic gradient $g_t$ is an unbiased estimator of the local gradient for each client:
$$\mathbb{E}\left[ g_{t} \right] = \nabla L_{t}, \quad \forall t, \qquad (7)$$
and its variance is bounded by $\sigma^2$:
$$\mathbb{E}\left[ \left\| g_{t} - \nabla L_{t} \right\|_2^2 \right] \le \sigma^2. \qquad (8)$$
Assumption 3.
(Bounded Expectation of Euclidean Norm of Stochastic Gradients). The expectation of the Euclidean norm of the stochastic gradient is bounded by $G$:
$$\mathbb{E}\left[ \left\| g_{t} \right\|_2 \right] \le G. \qquad (9)$$
Assumption 4.
(Lipschitz Continuity). Each local embedding function $f_i$ is $L_2$-Lipschitz continuous, that is,
$$\left\| f_i(\phi_{t_1}) - f_i(\phi_{t_2}) \right\|_2 \le L_2 \left\| \phi_{t_1} - \phi_{t_2} \right\|_2, \quad \forall t_1, t_2 > 0. \qquad (10)$$
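As a quick numeric illustration (not part of the analysis), the quadratic bound implied by Assumption 1 can be checked on a toy quadratic function, which is $L_1$-smooth with $L_1 = a$:

```python
import numpy as np

# Sanity check of the quadratic upper bound implied by L1-smoothness
# (cf. Assumption 1) on f(w) = 0.5 * a * w^2, for which L1 = a and the
# bound in fact holds with equality.
a = 3.0
f = lambda w: 0.5 * a * w * w
grad = lambda w: a * w

rng = np.random.default_rng(1)
holds = True
for _ in range(100):
    w1, w2 = rng.normal(), rng.normal()
    # f(w1) <= f(w2) + <grad(w2), w1 - w2> + (L1 / 2) * ||w1 - w2||^2
    bound = f(w2) + grad(w2) * (w1 - w2) + 0.5 * a * (w1 - w2) ** 2
    holds &= bool(f(w1) <= bound + 1e-9)
```

For a general smooth but non-quadratic function the bound is an inequality rather than an equality, but the same check applies.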
Key Lemmas
Lemma 1.
Proof.
Since this lemma holds for an arbitrary client, the client index is omitted in the following. Consider a single local update step; then
| (12) |
where (a) follows from the quadratic bound of $L_1$-Lipschitz smoothness in Assumption 1. Taking expectations of both sides with respect to the stochastic sampling, we have
| (13) | ||||
| (14) | ||||
| (15) | ||||
| (16) | ||||
| (17) |
where (b) and (d) follow from Assumption 2, and (c) follows from a standard variance decomposition of the stochastic gradient. Taking expectations on both sides and telescoping over the local steps, we have
| (18) |
∎
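The descent behaviour behind Lemma 1 can be illustrated numerically. The sketch below uses a deterministic quadratic objective (so the variance term vanishes) and a step size safely below the smoothness threshold; all values are illustrative, not the paper's setting:

```python
import numpy as np

# Toy check: for an L1-smooth objective and step size eta <= 1 / L1,
# gradient steps monotonically decrease the loss over E local iterations.
# Here f(w) = 0.5 * ||A w||^2, whose smoothness constant L1 is the
# largest eigenvalue of A^T A.
rng = np.random.default_rng(0)
A = rng.normal(size=(5, 5))
L1 = np.linalg.eigvalsh(A.T @ A).max()
eta = 1.0 / L1                     # safely inside the descent regime

f = lambda w: 0.5 * np.sum((A @ w) ** 2)
grad = lambda w: A.T @ (A @ w)

w = rng.normal(size=5)
losses = [f(w)]
for _ in range(20):                # E = 20 local iterations
    w = w - eta * grad(w)
    losses.append(f(w))

monotone = all(l2 <= l1 + 1e-12 for l1, l2 in zip(losses, losses[1:]))
```

With stochastic gradients, the same argument leaves an additive variance term per step, which is exactly what the telescoped bound in the lemma accounts for.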
Lemma 2.
Proof.
| (20) | ||||
| (21) | ||||
| (22) | ||||
| (23) | ||||
| (24) | ||||
| (25) | ||||
| (26) | ||||
| (27) | ||||
| (28) | ||||
| (29) | ||||
| (30) |
Taking expectations on both sides with respect to the stochastic sampling, we then have
| (32) | ||||
| (33) |
where (a) follows from the definition of the local loss function in Eq. 1, (b) follows from the local update rule, (c) follows from the definition of the global prototype in Eq. 2, (d) follows from the definition of the local prototype in Eq. 4, (e) and (h) follow from elementary norm inequalities, (f) follows from the $L_2$-Lipschitz continuity in Assumption 4, (g) follows from the fact that each client's local dataset is a subset of the overall dataset, and (i) follows from Assumption 3. ∎
Theorems
Theorem 1.
(One-round deviation). Let Assumptions 1 to 4 hold. For an arbitrary client, after every communication round, we have
| (34) |
Corollary 1.
(Non-convex FedProto convergence). The loss function of an arbitrary client monotonically decreases in every communication round when
| (35) |
and
| (36) |
Thus, the loss function converges.
Theorem 2.
(Non-convex convergence rate of FedProto). Let Assumptions 1 to 4 hold. For an arbitrary client, given any $\epsilon > 0$, after
| (37) |
communication rounds of FedProto, we have
| (38) |
if
| (39) |
and
| (40) |
Completing the Proof of Theorem 1 and Corollary 1
Completing the Proof of Theorem 2
Proof.
Taking expectations on both sides of Eq. 34 and telescoping over the communication rounds, together with the corresponding time steps within each round, we have
| (46) |
Given any $\epsilon > 0$, let
| (47) |
that is
| (48) |
Then the above inequality holds when
| (49) |
that is
| (50) |
So, we have
| (51) |
when
| (52) |
and
| (53) |
∎